#!/usr/bin/env python3
# policy/belief_manager.py
# 信念管理：初始化、传播、更新、归一化
import numpy as np
from math import dist

import logging

logger = logging.getLogger(__name__)


class BeliefManager:
    """Maintain a discrete belief over the evader's location on a graph.

    The belief is a probability vector ``belief_pr`` (one entry per node) plus
    a boolean support mask ``belief_pos``. ``update_belief`` runs the full
    per-frame pipeline: (throttled) random-walk propagation, removal of nodes
    visible to pursuers/observers, an optional reset to the evader's known
    node, and renormalization.
    """

    # Slack added before truncating propagation_time / frame_interval so that
    # float artefacts such as 0.6 / 0.2 == 2.9999999999999996 yield 3, not 2.
    _FRAME_EPS = 1e-9

    def __init__(
        self,
        node_num,
        adj_matrix,
        obs_dist=4,
        obs_nodes=None,
        pursuer_mode=1,
        node_distance=3.0,
        max_speed=2.0,
        frame_interval=0.2,  # physical parameters controlling propagation rate
    ):
        """Create a manager with an empty (all-zero) belief.

        Args:
            node_num: Number of graph nodes.
            adj_matrix: ``node_num x node_num`` indexable adjacency; an entry
                equal to 1 marks an edge (see ``get_neighbors``).
            obs_dist: A node counts as visible when its ``f_matrix`` distance
                to a viewer is ``<= obs_dist``.
            obs_nodes: Optional iterable (list, tuple or numpy array) of extra
                observer node ids.
            pursuer_mode: 0 resets the belief to the evader node on every
                update; 1 preserves (and propagates) it.
            node_distance: Distance between adjacent nodes.
            max_speed: Evader maximum speed (distance units per second).
            frame_interval: Seconds between consecutive ``update_belief`` calls.

        Raises:
            ValueError: If ``max_speed`` or ``frame_interval`` is not positive.
        """
        self.node_num = node_num
        self.adj_matrix = adj_matrix
        self.obs_dist = obs_dist
        # `obs_nodes or []` raises for a multi-element numpy array (ambiguous
        # truth value), so compare against None explicitly.
        self.obs_nodes = np.array(
            [] if obs_nodes is None else list(obs_nodes), dtype=int
        )
        self.belief_pr = np.zeros(node_num)
        self.belief_pos = np.zeros(node_num, dtype=bool)
        self.pursuer_mode = pursuer_mode  # 0: reset, 1: preserve

        # Propagation throttling: diffuse once per time needed to cross a node.
        self.node_distance = node_distance
        self.max_speed = max_speed
        self.frame_interval = frame_interval
        self.propagation_time, self.propagation_frames = self._propagation_schedule(
            node_distance, max_speed, frame_interval
        )
        self.frame_counter = 0  # frames since the last propagation step

    @classmethod
    def _propagation_schedule(cls, node_distance, max_speed, frame_interval):
        """Return ``(propagation_time, propagation_frames)`` for the given physics.

        ``propagation_time`` is the seconds needed to cross one node and
        ``propagation_frames`` is that time expressed in whole frames (>= 1).

        Raises:
            ValueError: If ``max_speed`` or ``frame_interval`` is not positive.
        """
        if max_speed <= 0:
            raise ValueError(f"max_speed must be positive, got {max_speed!r}")
        if frame_interval <= 0:
            raise ValueError(
                f"frame_interval must be positive, got {frame_interval!r}"
            )
        propagation_time = node_distance / max_speed
        frames = max(1, int(propagation_time / frame_interval + cls._FRAME_EPS))
        return propagation_time, frames

    def init_belief(self, evader_node):
        """Collapse the belief onto ``evader_node`` (probability 1).

        An out-of-range ``evader_node`` leaves the belief all-zero.
        """
        self.belief_pr.fill(0.0)
        self.belief_pos.fill(False)
        if 0 <= evader_node < self.node_num:
            self.belief_pr[evader_node] = 1.0
            self.belief_pos[evader_node] = True

    def expand_belief(self):
        """Diffuse the belief by one lazy random-walk step.

        Each node in the support keeps half of its mass and splits the other
        half equally among its neighbours; nodes without neighbours keep all
        of their mass. Does nothing when the support is empty.
        """
        if not np.any(self.belief_pos):
            return
        stay_prob = 0.5
        move_prob = 1.0 - stay_prob
        temp = np.zeros(self.node_num)

        for i in np.flatnonzero(self.belief_pos):
            pr_i = self.belief_pr[i]
            # Use the same neighbour set for the degree and the distribution,
            # so the step conserves mass even if adj_matrix holds values other
            # than 0/1 (np.sum of the row would over-count those).
            neighbors = self.get_neighbors(i)
            if neighbors:
                temp[i] += pr_i * stay_prob
                # Neighbour ids are unique, so fancy-index += is safe here.
                temp[neighbors] += pr_i * move_prob / len(neighbors)
            else:
                temp[i] += pr_i
        self.belief_pr = np.clip(temp, 0, 1)
        self.belief_pos = self.belief_pr > 1e-8

    def get_neighbors(self, node_id):
        """Return the ids of nodes ``j`` with ``adj_matrix[node_id][j] == 1``."""
        return [j for j in range(self.node_num) if self.adj_matrix[node_id][j] == 1]

    def remove_visible(self, p1_node, p2_node, f_matrix):
        """Zero the belief at every node visible from p1, p2 or an observer."""
        for node in range(self.node_num):
            if self.visible(p1_node, p2_node, node, f_matrix):
                self.belief_pos[node] = False
                self.belief_pr[node] = 0.0

    def visible(self, p1, p2, l, f_matrix):
        """Return True if node ``l`` is within ``obs_dist`` of a viewer.

        Viewers are ``p1``, ``p2`` and every in-range id in ``obs_nodes``;
        distances are read as ``f_matrix[viewer][l]``.
        """
        if f_matrix[p1][l] <= self.obs_dist or f_matrix[p2][l] <= self.obs_dist:
            return True
        # Skip out-of-range observer ids; a negative id would otherwise index
        # silently from the end of f_matrix.
        return any(
            0 <= obs < self.node_num and f_matrix[obs][l] <= self.obs_dist
            for obs in self.obs_nodes
        )

    def normalize_belief(self):
        """Rescale ``belief_pr`` to sum to 1 (no-op when it sums to 0)."""
        total = np.sum(self.belief_pr)
        if total > 0:
            self.belief_pr /= total

    def update_belief(
        self, evader_node, p1_node, p2_node, f_matrix, visible_after_move=False
    ):
        """Run one frame of the belief pipeline.

        Order: propagate (only every ``propagation_frames`` calls) -> remove
        visible nodes -> reset to ``evader_node`` if ``pursuer_mode == 0`` or
        ``visible_after_move`` -> normalize. Intended to be called once per
        frame, since the propagation throttle counts calls.
        """
        self.frame_counter += 1
        if self.frame_counter >= self.propagation_frames:
            self.expand_belief()
            self.frame_counter = 0

        self.remove_visible(p1_node, p2_node, f_matrix)
        # NOTE(review): in preserve mode, if every believed node is visible the
        # belief becomes all-zero and stays empty (expand_belief is a no-op)
        # until a reset happens; confirm this is the intended behaviour.
        if self.pursuer_mode == 0 or visible_after_move:
            self.init_belief(evader_node)
        self.normalize_belief()

    def get_belief_pr(self):
        """Return a copy of the probability vector."""
        return self.belief_pr.copy()

    def get_belief_pos(self):
        """Return a copy of the boolean support mask."""
        return self.belief_pos.copy()

    def update_propagation_params(
        self, node_distance=None, max_speed=None, frame_interval=None
    ):
        """Update physical parameters and recompute the propagation interval.

        Arguments left as None keep their current value. The frame counter is
        reset on success.

        Raises:
            ValueError: If the resulting ``max_speed`` or ``frame_interval`` is
                not positive; the manager's state is left unchanged then.
        """
        new_distance = self.node_distance if node_distance is None else node_distance
        new_speed = self.max_speed if max_speed is None else max_speed
        new_interval = self.frame_interval if frame_interval is None else frame_interval
        # Validate before mutating so a bad value cannot leave mixed state.
        time_s, frames = self._propagation_schedule(
            new_distance, new_speed, new_interval
        )
        self.node_distance = new_distance
        self.max_speed = new_speed
        self.frame_interval = new_interval
        self.propagation_time = time_s
        self.propagation_frames = frames
        self.frame_counter = 0
        logger.info(
            "信念传播参数更新: 间隔 = %d 帧 (时间=%.2fs)",
            self.propagation_frames,
            self.propagation_time,
        )
